import os
import torch
import torch.nn as nn
from einops import rearrange
from .vits_cmae import create_vit
import torch.nn.functional as F

class ImageEncoder(nn.Module):
    """Image encoder that wraps a ViT backbone built by ``create_vit``.

    The backbone is called with two inputs and the encoder returns its
    output with the first position along dimension 1 removed.

    Attributes set by ``__init__``:
        model_name: the ``model_name`` argument, stored as given.
        output_dim: the ``output_dim`` argument, stored as given.
        text_feat_dim: the ``text_feat_dim`` argument, stored as given.
        model: backbone returned by ``create_vit``.
        feature_dim: the width value returned by ``create_vit``.
    """

    # Weights used to initialise the backbone when ``pretrained`` is true.
    _DEIT_BASE_URL = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"

    def __init__(self,
                 model_name: str = "vit_base",
                 text_feat_dim: int = 768,
                 output_dim: int = 768,
                 hidden_dim: int = 2048,
                 pretrained: bool = True,
                 pretrained_pth: str = "./MITM.pth"
                 ):
        """Build the backbone and optionally load DeiT weights into it.

        Args:
            model_name: Backbone identifier; must contain ``"vit"``. The
                text after the first four characters is passed to
                ``create_vit`` (``"vit_base"`` -> ``"base"``).
            text_feat_dim: Stored on the instance; not otherwise used here.
            output_dim: Stored on the instance; not otherwise used here.
            hidden_dim: Accepted for interface compatibility; unused.
            pretrained: When true, download the DeiT-base checkpoint and
                load it non-strictly. When false, the backbone keeps the
                initialisation ``create_vit`` gave it and nothing is
                downloaded.
            pretrained_pth: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``model_name`` does not contain ``"vit"``.
        """
        super().__init__()

        self.model_name = model_name
        self.output_dim = output_dim
        self.text_feat_dim = text_feat_dim

        # Fail at construction time rather than leaving ``self.model``
        # undefined and crashing with an AttributeError in ``forward``.
        if "vit" not in model_name:
            raise ValueError(
                f"Unsupported model_name {model_name!r}: only ViT backbones "
                "(names containing 'vit') are supported."
            )

        vit_grad_ckpt = False
        vit_ckpt_layer = 0
        image_size = 224

        # Drop the 4-character "vit_" prefix to get the variant name.
        vit_name = model_name[4:]
        # NOTE(review): the trailing 0 is presumably a drop-path rate --
        # confirm against create_vit's signature.
        self.model, vision_width = create_vit(
            vit_name, image_size, vit_grad_ckpt, vit_ckpt_layer, 0)

        self.feature_dim = vision_width

        if pretrained:
            # NOTE(review): the checkpoint is always DeiT-base, whatever
            # variant ``vit_name`` selects -- confirm other variants are
            # compatible before using them with pretrained=True.
            checkpoint = torch.hub.load_state_dict_from_url(
                url=self._DEIT_BASE_URL,
                map_location="cpu", check_hash=True)
            state_dict = checkpoint["model"]
            # strict=False tolerates keys that are missing from, or extra
            # to, the backbone.
            self.model.load_state_dict(state_dict, strict=False)

    def vit_forward(self, x1, x2):
        """Run the backbone on ``x1`` and ``x2`` with ``register_blk=11``.

        NOTE(review): ``register_blk`` is forwarded as-is; its meaning is
        defined by the backbone -- confirm against ``create_vit``.
        """
        return self.model(x1, x2, register_blk=11)

    def forward(self, x1, x2, get_local=False):
        """Return backbone features with the first position of dim 1 removed.

        Args:
            x1: First backbone input.
            x2: Second backbone input.
            get_local: Accepted for interface compatibility; unused.

        Returns:
            A contiguous copy of ``backbone_output[:, 1:]``. The dropped
            position is presumably the class token -- TODO confirm.
        """
        img_feat = self.vit_forward(x1, x2)
        return img_feat[:, 1:].contiguous()
